{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bd76759e",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fca44089",
   "metadata": {},
   "source": [
    "# Interpretability using dog and cat classification\n",
    "\n",
    "In this notebook, we do a binary classification task between photos of dogs and cats. This then enables us to explore MONAI's interpretability classes:\n",
    "\n",
    "- `OcclusionSensitivity`,\n",
    "- `GradCAM++`,\n",
    "- `SmoothGrad`,\n",
    "- `GuidedBackpropGrad`, and\n",
    "- `GuidedBackpropSmoothGrad`.\n",
    "\n",
    "We use a pre-trained Densenet, which enables us to do very quick training. For brevity, we also don't bother with splitting into training and validation datasets.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/interpretability/cats_and_dogs.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd47a584",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "978bbaf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n",
    "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
    "!python -c \"import sklearn\" || pip install -q scikit-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8c2f96d",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7e7829",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "import os\n",
    "from enum import Enum\n",
    "import torch\n",
    "from monai.config import print_config\n",
    "from monai.transforms import (\n",
    "    EnsureChannelFirstd,\n",
    "    Compose,\n",
    "    DivisiblePadd,\n",
    "    Lambdad,\n",
    "    LoadImaged,\n",
    "    Resized,\n",
    "    Rotate90d,\n",
    "    ScaleIntensityd,\n",
    ")\n",
    "from monai.networks.utils import eval_mode\n",
    "from contextlib import nullcontext\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib import colors\n",
    "from monai.data import Dataset, DataLoader\n",
    "from monai.networks.nets import DenseNet121\n",
    "from monai.data.utils import pad_list_data_collate\n",
    "from random import shuffle\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm, trange\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
    "from monai.visualize import (\n",
    "    GradCAMpp,\n",
    "    OcclusionSensitivity,\n",
    "    SmoothGrad,\n",
    "    GuidedBackpropGrad,\n",
    "    GuidedBackpropSmoothGrad,\n",
    ")\n",
    "from monai.utils import set_determinism\n",
    "from monai.apps import download_and_extract\n",
    "from urllib.request import urlretrieve\n",
    "import tempfile\n",
    "\n",
    "# Prefer the GPU when torch can see one; otherwise run on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Fix the seed (0) through MONAI's determinism helper before any data is shuffled.\n",
    "set_determinism(seed=0)\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2553f762",
   "metadata": {},
   "source": [
    "# Download and extract data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cad0eb4f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Use $MONAI_DATA_DIRECTORY when it is set (creating it if needed); otherwise a fresh temp dir.\n",
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is None:\n",
    "    root_dir = tempfile.mkdtemp()\n",
    "else:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "    root_dir = directory\n",
    "data_path = os.path.join(root_dir, \"CatsAndDogs\")\n",
    "\n",
    "# Skip the download when the folder already holds at least 25,000 jpgs.\n",
    "jpg_pattern = os.path.join(data_path, \"**\", \"**\", \"*.jpg\")\n",
    "num_jpgs_on_disk = len(glob(jpg_pattern))\n",
    "if num_jpgs_on_disk < 25000:\n",
    "    url = (\n",
    "        \"https://download.microsoft.com/download/3/E/1/\"\n",
    "        + \"3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip\"\n",
    "    )\n",
    "    md5 = \"e137a4507370d942469b6d267a24ea04\"\n",
    "    download_and_extract(url, output_dir=data_path, hash_val=md5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b6c637de",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num im cats: 500\n",
      "Num im dogs: 500\n",
      "Num images to be used: 1000\n"
     ]
    }
   ],
   "source": [
    "class Animals(Enum):\n",
    "    \"\"\"The two classes of photo; each member's value is used as the integer label.\"\"\"\n",
    "\n",
    "    cat = 0\n",
    "    dog = 1\n",
    "\n",
    "\n",
    "def remove_non_rgb(data, max_num=None):\n",
    "    \"\"\"Keep only the entries whose image loads as a 3-D array with 3 values on its last axis.\n",
    "\n",
    "    Files of 100 bytes or fewer are skipped without being loaded. When ``max_num`` is\n",
    "    given, scanning stops once exactly ``max_num`` entries have been kept.\n",
    "    \"\"\"\n",
    "    load_image = LoadImaged(\"image\")\n",
    "    kept = []\n",
    "    for item in data:\n",
    "        if os.path.getsize(item[\"image\"]) > 100:\n",
    "            array = load_image(item)[\"image\"]\n",
    "            is_rgb = array.ndim == 3 and array.shape[-1] == 3\n",
    "            if is_rgb:\n",
    "                kept.append(item)\n",
    "        if max_num is not None and len(kept) == max_num:\n",
    "            break\n",
    "    return kept\n",
    "\n",
    "\n",
    "def get_data(animal, max_num=None):\n",
    "    \"\"\"Build a shuffled list of ``{\"image\": path, \"label\": int}`` dicts for one class.\n",
    "\n",
    "    Args:\n",
    "        animal: member of ``Animals``; its capitalised name is the sub-folder of\n",
    "            ``PetImages`` that is searched and its value becomes the label.\n",
    "        max_num: forwarded to ``remove_non_rgb`` to cap how many entries are kept.\n",
    "\n",
    "    Returns:\n",
    "        The entries that ``remove_non_rgb`` keeps, in shuffled order.\n",
    "    \"\"\"\n",
    "    folder = os.path.join(data_path, \"PetImages\", animal.name.capitalize())\n",
    "    # glob returns paths in arbitrary, filesystem-dependent order; sort first so that a\n",
    "    # seeded shuffle selects the same files on every machine.\n",
    "    files = sorted(glob(os.path.join(folder, \"*.jpg\")))\n",
    "    data = [{\"image\": path, \"label\": animal.value} for path in files]\n",
    "    shuffle(data)\n",
    "    return remove_non_rgb(data, max_num)\n",
    "\n",
    "\n",
    "# 500 of each class as this is sufficient\n",
    "cats = get_data(Animals.cat, max_num=500)\n",
    "dogs = get_data(Animals.dog, max_num=500)\n",
    "\n",
    "# Mix the two classes so that batches are not single-class.\n",
    "all_data = cats + dogs\n",
    "shuffle(all_data)\n",
    "\n",
    "print(\n",
    "    f\"Num im cats: {len(cats)}\",\n",
    "    f\"Num im dogs: {len(dogs)}\",\n",
    "    f\"Num images to be used: {len(all_data)}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa141df4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "batch_size = 20\n",
    "divisible_factor = 20\n",
    "\n",
    "# Per-sample preprocessing, applied to the \"image\" entry in the listed order.\n",
    "transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=\"image\"),\n",
    "        EnsureChannelFirstd(keys=\"image\"),\n",
    "        ScaleIntensityd(keys=\"image\"),\n",
    "        Rotate90d(keys=\"image\", k=3),\n",
    "        DivisiblePadd(keys=\"image\", k=divisible_factor),\n",
    "    ]\n",
    ")\n",
    "\n",
    "ds = Dataset(all_data, transforms)\n",
    "\n",
    "# Images differ in size, so batches are collated with padding; the last partial batch is dropped.\n",
    "loader_options = {\n",
    "    \"batch_size\": batch_size,\n",
    "    \"shuffle\": True,\n",
    "    \"num_workers\": 10,\n",
    "    \"collate_fn\": pad_list_data_collate,\n",
    "    \"drop_last\": True,\n",
    "}\n",
    "dl = DataLoader(ds, **loader_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80f735b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def imshow(data):\n",
    "    \"\"\"Show every sample of ``data`` in a single figure, titled with its class name.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of dicts holding an \"image\" tensor (its first axis is moved last\n",
    "            for plotting) and a \"label\" that is a valid ``Animals`` value.\n",
    "    \"\"\"\n",
    "    nims = len(data)\n",
    "    if nims == 0:\n",
    "        # plt.subplots cannot build a grid with zero columns; nothing to draw anyway\n",
    "        return\n",
    "    if nims < 6:\n",
    "        shape = (1, nims)\n",
    "    else:\n",
    "        # floor(sqrt) x ceil(sqrt) can be too small (7 images -> 2 x 3 = 6 cells), so\n",
    "        # derive the column count from the row count to guarantee a cell per image.\n",
    "        rows = int(np.floor(np.sqrt(nims)))\n",
    "        shape = rows, int(np.ceil(nims / rows))\n",
    "    _, axes = plt.subplots(*shape, figsize=(20, 20))\n",
    "    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array\n",
    "    flat_axes = np.asarray(axes).ravel()\n",
    "    for ax in flat_axes:\n",
    "        # switch every cell off up front so cells without an image stay blank\n",
    "        ax.axis(\"off\")\n",
    "    for d, ax in zip(data, flat_axes):\n",
    "        # channel last for matplotlib\n",
    "        im = np.moveaxis(d[\"image\"].detach().cpu().numpy(), 0, -1)\n",
    "        ax.imshow(im, cmap=\"gray\")\n",
    "        ax.set_title(Animals(d[\"label\"]).name, fontsize=25)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cb69127",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Preview 12 distinct, randomly chosen samples after preprocessing\n",
    "rand_idxs = np.random.choice(len(ds), size=12, replace=False)\n",
    "rand_samples = [ds[idx] for idx in rand_idxs]\n",
    "imshow(rand_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "158ad896",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# 2D DenseNet121 with pretrained weights, 3 input channels and 2 output classes\n",
    "model = DenseNet121(spatial_dims=2, in_channels=3, out_channels=2, pretrained=True)\n",
    "model = model.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
    "\n",
    "# Mixed-precision switch; the gradient scaler only exists when it is on.\n",
    "use_amp = True\n",
    "if use_amp:\n",
    "    label_dtype = torch.float16\n",
    "    scaler = torch.GradScaler(\"cuda\")\n",
    "else:\n",
    "    label_dtype = torch.float32\n",
    "    scaler = None\n",
    "\n",
    "\n",
    "def criterion(y_pred, y):\n",
    "    \"\"\"Cross-entropy between logits ``y_pred`` and targets ``y``, summed (not averaged) over the batch.\"\"\"\n",
    "    cross_entropy = torch.nn.functional.cross_entropy\n",
    "    return cross_entropy(y_pred, y, reduction=\"sum\")\n",
    "\n",
    "\n",
    "def get_num_correct(y_pred, y):\n",
    "    \"\"\"Count the rows of ``y_pred`` whose arg-max along dim 1 equals the matching entry of ``y``.\"\"\"\n",
    "    predicted = y_pred.argmax(dim=1)\n",
    "    matches = predicted == y\n",
    "    return matches.sum().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d4fa4e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "max_epochs = 2\n",
    "# Make sure the network is in training mode, so this cell can be re-run after evaluation.\n",
    "model.train()\n",
    "for epoch in trange(max_epochs, desc=\"Epoch\"):\n",
    "    # running totals for the epoch: summed loss and number of correct predictions\n",
    "    loss, acc = 0, 0\n",
    "    for data in dl:\n",
    "        inputs, labels = data[\"image\"].to(device), data[\"label\"].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        with torch.autocast(\"cuda\", enabled=use_amp) if use_amp else nullcontext():\n",
    "            outputs = model(inputs)\n",
    "            train_loss = criterion(outputs, labels)\n",
    "            acc += get_num_correct(outputs, labels)\n",
    "        if use_amp:\n",
    "            # back-propagate the scaled loss, then let the scaler drive the optimizer step\n",
    "            scaler.scale(train_loss).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "        else:\n",
    "            # back-propagate the batch loss tensor; `loss` is only the running numeric total\n",
    "            train_loss.backward()\n",
    "            optimizer.step()\n",
    "        loss += train_loss.item()\n",
    "    # the loader drops the last partial batch, so every batch holds batch_size samples\n",
    "    num_seen = len(dl) * batch_size\n",
    "    loss /= num_seen\n",
    "    acc /= num_seen\n",
    "    print(f\"Epoch {epoch+1}, loss: {loss:.3f}, acc: {acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ebcf958",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Leave the model in eval mode for the rest of the notebook; eval_mode() alone would\n",
    "# hand back whatever mode the model had on entry.\n",
    "model.eval()\n",
    "with eval_mode(model):\n",
    "    logit_batches = [torch.tensor([], dtype=torch.float32, device=device)]\n",
    "    label_batches = [torch.tensor([], dtype=torch.long, device=device)]\n",
    "\n",
    "    for data in tqdm(dl):\n",
    "        images, labels = data[\"image\"].to(device), data[\"label\"].to(device)\n",
    "        with torch.autocast(\"cuda\", enabled=use_amp) if use_amp else nullcontext():\n",
    "            outputs = model(images).detach()\n",
    "        logit_batches.append(outputs)\n",
    "        label_batches.append(labels)\n",
    "\n",
    "    # one concatenation at the end instead of growing the tensors every batch\n",
    "    y = torch.cat(label_batches, dim=0)\n",
    "    y_pred = torch.cat(logit_batches, dim=0).argmax(dim=1)\n",
    "\n",
    "    # rows are normalised over the true class\n",
    "    cm = confusion_matrix(y.cpu().numpy(), y_pred.cpu().numpy(), normalize=\"true\")\n",
    "    class_names = [a.name for a in Animals]\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    cm_fig, cm_ax = plt.subplots(1, 1, facecolor=\"white\")\n",
    "    _ = disp.plot(ax=cm_ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2958a455",
   "metadata": {},
   "source": [
    "# Interpretability\n",
    "\n",
    "Now we compare our different saliency methods. Initially, the resulting images can be tricky to decipher.\n",
    "\n",
    "## Occlusion sensitivity\n",
    "With occlusion sensitivity we iteratively block off part of the image and then we record the changes in certainty of the inferred class. This means that for instances where the network correctly infers the image type, we expect the certainty to drop as we occlude important parts of the image. Hence, for correct inference, blue parts of the image imply importance.\n",
    "\n",
    "This is also true when the network **incorrectly** infers the image; blue areas were important in inferring the given class.\n",
    "\n",
    "## GradCAM\n",
    "The user chooses a layer of the network that interests them and the gradient is calculated at this point. The chosen layer is typically towards the bottom of the network, as all the features have hopefully been extracted by this point. The images have been downsampled many times, and so the resulting images are linearly upsampled to match the size of the input image. As with occlusion sensitivity, blue parts of the image imply importance in the decision making process.\n",
    "\n",
    "## VanillaGrad\n",
    "`VanillaGrad` looks at the gradient of the image after putting it through the network. It is the basis for `SmoothGrad`, `GuidedBackpropGrad`, and `GuidedBackpropSmoothGrad`. For all of these methods, red areas imply importance in the decision making process.\n",
    "\n",
    "`VanillaGrad` is omitted in this notebook to save space but the user can add it in if interested.\n",
    "\n",
    "## SmoothGrad\n",
    " `SmoothGrad` repeatedly (default=25) adds noise to the input image and performs `VanillaGrad`. The results are then averaged.\n",
    "\n",
    "## GuidedBackpropGrad and GuidedBackpropSmoothGrad\n",
    "`GuidedBackpropGrad` and `GuidedBackpropSmoothGrad` extend upon `VanillaGrad` and `SmoothGrad`, respectively. They both store the gradients at certain points of the network; by default, these are the `ReLU` layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d217917",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# To list candidate layer names: for name, _ in model.named_modules(): print(name)\n",
    "target_layer = \"class_layers.relu\"\n",
    "gradcampp = GradCAMpp(model, target_layers=target_layer)\n",
    "\n",
    "occlusion_options = {\n",
    "    \"mask_size\": 32,\n",
    "    \"n_batch\": batch_size,\n",
    "    \"overlap\": 0.5,\n",
    "    \"verbose\": False,\n",
    "}\n",
    "occ_sens = OcclusionSensitivity(model, **occlusion_options)\n",
    "\n",
    "# gradient-based explainers, all wrapping the same trained model\n",
    "smooth_grad = SmoothGrad(model, verbose=False)\n",
    "guided_vanilla = GuidedBackpropGrad(model)\n",
    "guided_smooth = GuidedBackpropSmoothGrad(model, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a960eb4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def saliency(model, d):\n",
    "    \"\"\"Compute every saliency panel for one sample.\n",
    "\n",
    "    Args:\n",
    "        model: network used for the initial prediction. The explainer objects\n",
    "            (``occ_sens``, ``gradcampp``, ``smooth_grad``, ...) are notebook-level and\n",
    "            use whichever model they were constructed with.\n",
    "        d: dict whose \"image\" entry is the image; a batch axis is added here.\n",
    "\n",
    "    Returns:\n",
    "        Three equal-length lists: the images to plot, their titles, and a flag per\n",
    "        image saying whether it should be shown on a log colour scale.\n",
    "    \"\"\"\n",
    "    img = torch.as_tensor(d[\"image\"])[None].to(device)\n",
    "    pred_logits = model(img)\n",
    "    pred_label = pred_logits.argmax(dim=1).item()\n",
    "    class_probs = torch.nn.functional.softmax(pred_logits, dim=1)\n",
    "    pred_prob = int(class_probs[0, pred_label].item() * 100)\n",
    "\n",
    "    # Each panel is (image, title, use_log_scale); the input image comes first.\n",
    "    panels = [(torch.moveaxis(img, 1, -1), f\"Pred: {Animals(pred_label).name} ({pred_prob}%)\", False)]\n",
    "\n",
    "    # Occlusion sensitivity map for the predicted class\n",
    "    occ_map, _ = occ_sens(img)\n",
    "    panels.append((occ_map[0, pred_label][None], \"Occ. sens.\", False))\n",
    "\n",
    "    # GradCAM++ for the predicted class\n",
    "    res_cam_pp = gradcampp(x=img, class_idx=pred_label)[0]\n",
    "    panels.append((res_cam_pp, \"GradCAMpp\", False))\n",
    "\n",
    "    # Remaining gradient-based methods, reduced from RGB to one value per pixel\n",
    "    gradient_methods = ((\"Smooth\", smooth_grad), (\"GuidedVa\", guided_vanilla), (\"GuidedSm\", guided_smooth))\n",
    "    for name, method in gradient_methods:\n",
    "        grads = method(img)\n",
    "        magnitude = torch.sum(grads**2, dim=1) ** 0.5\n",
    "        panels.append((magnitude, name, True))\n",
    "\n",
    "    ims = [panel[0] for panel in panels]\n",
    "    titles = [panel[1] for panel in panels]\n",
    "    log_scales = [panel[2] for panel in panels]\n",
    "    return ims, titles, log_scales"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43947fe3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def add_im(im, title, log_scale, row, col, num_examples, axes=None):\n",
    "    \"\"\"Draw one panel into the subplot grid.\n",
    "\n",
    "    Args:\n",
    "        im: image to show; only its first element along axis 0 is plotted.\n",
    "        title: panel title (\" log\" is appended when ``log_scale`` is set).\n",
    "        log_scale: use a logarithmic colour normalisation.\n",
    "        row, col: position of the panel in the grid.\n",
    "        num_examples: number of rows in the grid; with a single row the grid is\n",
    "            indexed by column only.\n",
    "        axes: grid of Axes to draw into. When omitted, the notebook-level variable\n",
    "            ``axes`` is used, which keeps older call sites working.\n",
    "    \"\"\"\n",
    "    if axes is None:\n",
    "        axes = globals()[\"axes\"]\n",
    "    ax = axes[row, col] if num_examples > 1 else axes[col]\n",
    "    if isinstance(im, torch.Tensor):\n",
    "        im = im.detach().cpu()\n",
    "    if log_scale:\n",
    "        im_show = ax.imshow(im[0], cmap=\"jet\", norm=colors.LogNorm())\n",
    "        title += \" log\"\n",
    "    else:\n",
    "        im_show = ax.imshow(im[0], cmap=\"jet\")\n",
    "    ax.set_title(title, fontsize=25)\n",
    "    ax.axis(\"off\")\n",
    "    if col > 0:\n",
    "        # no colour bar for the first column; use the figure that owns this Axes\n",
    "        # rather than a notebook-level `fig`\n",
    "        ax.figure.colorbar(im_show, ax=ax)\n",
    "\n",
    "\n",
    "def add_row(ims, titles, log_scales, row, axes, num_examples):\n",
    "    \"\"\"Draw one row of panels (one per entry of ``ims``) into ``axes`` at ``row``.\"\"\"\n",
    "    for col, (im, title, log_scale) in enumerate(zip(ims, titles, log_scales)):\n",
    "        if log_scale and im.min() < 0:\n",
    "            # shift so the minimum is 0; done out of place so the caller's data is untouched\n",
    "            im = im - im.min()\n",
    "        add_im(im, title, log_scale, row, col, num_examples, axes=axes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9030b675",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_examples = 5\n",
    "# Draw indices and load only the chosen samples. Handing `ds` itself to np.random.choice\n",
    "# makes numpy turn the whole dataset into an array first, i.e. load every image.\n",
    "example_idxs = np.random.choice(len(ds), size=num_examples, replace=False)\n",
    "rand_data = [ds[i] for i in example_idxs]\n",
    "tr = tqdm(rand_data)\n",
    "for row, d in enumerate(tr):\n",
    "    tr.set_description(f\"img shape: {d['image'].shape[1:]}\")\n",
    "    ims, titles, log_scales = saliency(model, d)\n",
    "    if row == 0:\n",
    "        # the number of columns is only known once the first sample has been processed\n",
    "        num_cols = len(ims)\n",
    "        subplot_shape = [num_examples, num_cols]\n",
    "        figsize = [i * 5 for i in subplot_shape][::-1]\n",
    "        fig, axes = plt.subplots(*subplot_shape, figsize=figsize, facecolor=\"white\")\n",
    "    add_row(ims, titles, log_scales, row, axes, num_examples)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c32b4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three extra photos, all hosted under the same release of MONAI-extra-test-data.\n",
    "extra_data_root = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/\"\n",
    "extra_data = [\n",
    "    {\"image\": extra_data_root + file_name}\n",
    "    for file_name in (\"rabbit23.jpg\", \"test-cats-dogs2.jpg\", \"test-cats-dogs3.jpg\")\n",
    "]\n",
    "\n",
    "\n",
    "def download_url(url):\n",
    "    \"\"\"Download ``url`` into a new temporary ``.jpg`` file and return that file's path.\"\"\"\n",
    "    # mkstemp creates the file and leaves it in place. Taking `.name` from an unreferenced\n",
    "    # NamedTemporaryFile gives a path whose file is deleted once that object is closed.\n",
    "    fd, fname = tempfile.mkstemp(suffix=\".jpg\")\n",
    "    os.close(fd)  # urlretrieve opens the path itself\n",
    "    return urlretrieve(url, fname)[0]\n",
    "\n",
    "\n",
    "# Fetch the image first, reuse the training preprocessing, then resize to 320 x 320.\n",
    "extra_steps = [\n",
    "    Lambdad(keys=\"image\", func=download_url),\n",
    "    transforms,\n",
    "    Resized(keys=\"image\", spatial_size=(320, 320)),\n",
    "]\n",
    "new_transforms = Compose(extra_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ce767c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "num_examples = len(extra_data)\n",
    "for row, d in enumerate(tqdm(extra_data)):\n",
    "    d = new_transforms(d)\n",
    "    ims, titles, log_scales = saliency(model, d)\n",
    "    if row == 0:\n",
    "        # one row per example, one column per panel; every cell is 5 x 5 inches\n",
    "        num_cols = len(ims)\n",
    "        subplot_shape = [num_examples, num_cols]\n",
    "        figsize = [num_cols * 5, num_examples * 5]\n",
    "        fig, axes = plt.subplots(num_examples, num_cols, figsize=figsize, facecolor=\"white\")\n",
    "    add_row(ims, titles, log_scales, row, axes, num_examples)\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
